import torch
import numpy as np
from typing import List, Dict, Tuple
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import random

class LSH:
    """Random-hyperplane locality-sensitive hashing index over embedding vectors.

    Each of ``num_hash_tables`` tables owns an independent Gaussian projection
    matrix of shape ``(embedding_dim, num_hash_functions)``. A vector's key in a
    table is the sign pattern of its projections, encoded as a string of
    '0'/'1' characters (one per hash function). Each table maps a key to the
    list of indices whose vectors hashed to it.
    """

    def __init__(self, embedding_dim: int, num_hash_tables: int = 8, num_hash_functions: int = 4, device: torch.device = None):
        """
        Initialize LSH with random projection vectors.

        Args:
            embedding_dim: Dimension of the input embeddings.
            num_hash_tables: Number of hash tables to use.
            num_hash_functions: Number of hash functions (bits) per table.
            device: Device to use for computations; defaults to CUDA when
                available, otherwise CPU.
        """
        self.embedding_dim = embedding_dim
        self.num_hash_tables = num_hash_tables
        self.num_hash_functions = num_hash_functions
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # dtype of the projections; inputs are cast to it because torch.matmul
        # does not type-promote (bf16/fp16 embeddings would otherwise raise).
        self.dtype = torch.get_default_dtype()

        # One random projection matrix per hash table.
        self.projection_vectors = [
            torch.randn(embedding_dim, num_hash_functions, device=self.device, dtype=self.dtype)
            for _ in range(num_hash_tables)
        ]

        # hash_tables[t] maps a bit-string key to the indices stored in that bucket.
        self.hash_tables: List[Dict[str, List[int]]] = [
            {} for _ in range(num_hash_tables)
        ]

    def _hash_codes(self, embeddings: torch.Tensor, table_idx: int) -> List[str]:
        """Return one bit-string key per row of the 2-D tensor ``embeddings``."""
        embeddings = embeddings.to(device=self.device, dtype=self.dtype)
        projections = torch.matmul(embeddings, self.projection_vectors[table_idx])
        # One host transfer for the whole batch instead of one per row.
        bits = (projections > 0).int().cpu().tolist()
        return [''.join(map(str, row)) for row in bits]

    def _get_hash(self, embedding: torch.Tensor, table_idx: int) -> str:
        """Compute the bit-string key of a single 1-D embedding in table ``table_idx``."""
        return self._hash_codes(embedding.unsqueeze(0), table_idx)[0]

    def add_embeddings(self, embeddings: torch.Tensor, indices: List[int]):
        """
        Add embeddings to every hash table.

        Args:
            embeddings: Tensor of shape (num_embeddings, embedding_dim).
            indices: Index stored for each row of ``embeddings``, in row order.

        Raises:
            ValueError: If ``indices`` and ``embeddings`` differ in length.
        """
        if len(indices) != len(embeddings):
            raise ValueError(
                f"got {len(indices)} indices for {len(embeddings)} embeddings"
            )
        # Cast once up front so the per-table hashing below does not re-copy.
        embeddings = embeddings.to(device=self.device, dtype=self.dtype)
        for table_idx, table in enumerate(self.hash_tables):
            # Hash all rows with a single matmul per table; bucket order still
            # follows row order, as with per-row insertion.
            for code, idx in zip(self._hash_codes(embeddings, table_idx), indices):
                table.setdefault(code, []).append(idx)

    def query(self, query_embedding: torch.Tensor, top_k: int = 4000) -> List[int]:
        """
        Find indices that collide with ``query_embedding``, ranked by collision count.

        Args:
            query_embedding: 1-D query vector of length ``embedding_dim``.
            top_k: Maximum number of indices to return.

        Returns:
            Indices that share a bucket with the query in at least one table,
            sorted by descending number of colliding tables. Ties are ordered by
            ascending index. Returns ``[]`` when nothing collides.
        """
        query_embedding = query_embedding.to(device=self.device, dtype=self.dtype)

        # Collect the matching bucket from every table (a candidate may repeat).
        buckets = []
        for table_idx, table in enumerate(self.hash_tables):
            bucket = table.get(self._get_hash(query_embedding, table_idx))
            if bucket:
                buckets.append(torch.tensor(bucket, device=self.device))

        if not buckets:
            return []

        # torch.unique returns candidates in ascending order, so a stable sort on
        # the counts gives a deterministic tie-break by index.
        candidates, counts = torch.unique(torch.cat(buckets), return_counts=True)
        order = torch.sort(counts, descending=True, stable=True).indices
        return candidates[order][:top_k].cpu().tolist()

class GemmaLSH:
    """Next-token prediction for a causal LM, with candidate tokens retrieved by LSH.

    The LSH index is built over the rows of the model's output-embedding matrix,
    optionally after SVD-based whitening. At query time, the last-layer hidden
    state of the prompt's final token (or a transform of it) is hashed to
    retrieve candidate token ids, which are then scored against the hidden state.
    """

    def __init__(self, model, tokenizer, whitening=False, num_hash_tables=8, num_hash_functions=4):
        """
        Args:
            model: Causal LM exposing ``get_output_embeddings()`` and returning
                ``hidden_states`` when called with ``output_hidden_states=True``.
                It is moved to the selected device.
            tokenizer: Tokenizer matching ``model``.
            whitening: If True, index SVD-whitened, centered output embeddings
                instead of the raw ones.
            num_hash_tables: Forwarded to ``LSH``.
            num_hash_functions: Forwarded to ``LSH``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.tokenizer = tokenizer

        output_weight = self.model.get_output_embeddings().weight
        embedding_dim = output_weight.shape[1]
        self.lsh = LSH(embedding_dim=embedding_dim, device=self.device,
                       num_hash_tables=num_hash_tables,
                       num_hash_functions=num_hash_functions)

        # Output-embedding matrix, one row per token id, in the model's own dtype;
        # logits are computed as hidden @ token_embeddings.T.
        self.token_embeddings = output_weight.detach().to(self.device)
        self.whitening = whitening

        if whitening:
            # torch.linalg.svd does not support half precision, so the whitening
            # statistics are computed in float32 (a no-op for float32 models).
            embeddings_f32 = self.token_embeddings.float()
            self.mean = embeddings_f32.mean(dim=0)
            centered = embeddings_f32 - self.mean
            # centered = u @ diag(s) @ vt
            u, s, vt = torch.linalg.svd(centered, full_matrices=False)

            self.whitening_matrix = torch.matmul(
                torch.matmul(vt.T, torch.diag(1.0 / torch.sqrt(s + 1e-6))),
                vt
            )
            self.inverse_whitening_matrix = torch.matmul(
                torch.matmul(vt.T, torch.diag(torch.sqrt(s))),
                vt
            )
            # NOTE(review): when all s > 0, u @ vt == centered @ vt.T @ diag(1/s) @ vt,
            # i.e. the indexed keys are scaled by s^-1, whereas whitening_matrix and
            # inverse_whitening_matrix use s^-1/2 and s^+1/2. Confirm these
            # exponents are meant to differ.
            self._build_lsh_index(u @ vt)
        else:
            self._build_lsh_index(self.token_embeddings)

    def _lsh_input(self, x: torch.Tensor) -> torch.Tensor:
        """Cast ``x`` to the LSH projection dtype (torch.matmul does not type-promote)."""
        projections = self.lsh.projection_vectors
        return x.to(projections[0].dtype) if projections else x

    def _build_lsh_index(self, token_embeddings: torch.Tensor):
        """Hash every row of ``token_embeddings`` into the index; row i is stored as token id i."""
        indices = list(range(len(token_embeddings)))
        self.lsh.add_embeddings(self._lsh_input(token_embeddings), indices)

    def get_last_layer_embedding(self, prompt: str) -> torch.Tensor:
        """Return the last element of ``hidden_states`` at the prompt's final token position.

        NOTE(review): whether this state already includes the model's final norm
        depends on the model implementation; confirm it is the vector the output
        embeddings are applied to.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            last_hidden_state = outputs.hidden_states[-1]
            last_token_embedding = last_hidden_state[0, -1, :]

        return last_token_embedding

    def _logits(self, hidden: torch.Tensor) -> torch.Tensor:
        """Return full-vocabulary logits: ``hidden`` dotted with every output embedding."""
        return torch.matmul(hidden, self.token_embeddings.T)

    def _decode_pairs(self, token_ids, probs) -> List[Tuple[str, float]]:
        """Pair each token id's decoded string with its probability."""
        return [(self.tokenizer.decode([token_id]), prob)
                for token_id, prob in zip(token_ids, probs)]

    def _predict_full_vocab(self, hidden, top_k, return_logits):
        """Exact prediction: softmax over the logits of every vocabulary token."""
        logits = self._logits(hidden)
        if return_logits:
            return logits
        probs = F.softmax(logits, dim=-1)
        # Clamp k so a top_k larger than the vocabulary does not make topk raise.
        top_probs, top_ids = torch.topk(probs, k=min(top_k, probs.shape[-1]), dim=-1)
        return self._decode_pairs(top_ids.tolist(), top_probs.tolist())

    def _predict_lsh(self, hidden, top_k, top_candidates, dual_transform, return_logits):
        """Approximate prediction restricted to the token ids retrieved from the LSH index."""
        # Full logits are only needed for the dual transform or the masked-logits output.
        logits = self._logits(hidden) if (return_logits or dual_transform) else None

        query = hidden
        if dual_transform:
            # Probability-weighted average of the output embeddings.
            query = torch.matmul(F.softmax(logits, dim=-1), self.token_embeddings)

        if self.whitening:
            transform = self.whitening_matrix if dual_transform else self.inverse_whitening_matrix
            query = torch.matmul(query - self.mean, transform)

        candidate_ids = self.lsh.query(self._lsh_input(query), top_k=top_candidates)

        if return_logits:
            # Keep logits of retrieved tokens; every other position becomes -inf.
            mask = torch.ones_like(logits, dtype=torch.bool)
            keep = torch.tensor(candidate_ids, device=logits.device, dtype=torch.long)
            mask[keep] = False
            return logits.masked_fill(mask, float("-inf"))

        # Softmax over the retrieved candidates only, scored with the raw hidden state.
        candidate_scores = torch.matmul(hidden, self.token_embeddings[candidate_ids].T)
        probs = F.softmax(candidate_scores, dim=-1)
        pairs = self._decode_pairs(candidate_ids, probs.tolist())
        return sorted(pairs, key=lambda pair: pair[1], reverse=True)[:top_k]

    def predict_next_token(self, prompt: str, top_k: int = 10,
                           top_candidates=4000,
                           dual_transform=False,
                           return_original=False,
                           return_logits=False) -> List[Tuple[str, float]]:
        """
        Predict the next token of ``prompt``, either exactly or via LSH retrieval.

        Args:
            prompt: Input prompt.
            top_k: Number of (token, probability) pairs to return.
            top_candidates: Maximum number of candidate token ids requested
                from the LSH index.
            dual_transform: In LSH mode, query the index with the
                softmax-weighted average of the output embeddings instead of
                the raw hidden state.
            return_original: Skip LSH and score the full vocabulary exactly.
            return_logits: Return a 1-D logits tensor instead of decoded pairs.
                In LSH mode, logits of tokens not retrieved are set to -inf.

        Returns:
            Up to ``top_k`` (token, probability) tuples sorted by descending
            probability. Exact mode uses a softmax over the whole vocabulary;
            LSH mode uses a softmax over the retrieved candidates only. When
            ``return_logits`` is True, a logits tensor is returned instead.
        """
        hidden = self.get_last_layer_embedding(prompt)
        if return_original:
            return self._predict_full_vocab(hidden, top_k, return_logits)
        return self._predict_lsh(hidden, top_k, top_candidates, dual_transform, return_logits)
                
def main():
    """Compare exact, LSH, and dual-transform LSH next-token predictions on sample prompts."""
    import os

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = "google/gemma-7b"
    # Expand "~" explicitly rather than relying on the download library to do it;
    # an unexpanded path would be treated as a relative directory literally named "~".
    cache_dir = os.path.expanduser("~/gemma_cache")
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir).to(device)

    gemma_lsh = GemmaLSH(model, tokenizer)

    # Test prompts
    test_prompts = [
        # "Once upon a time",
        "The quick brown fox",
        "In a world where",
        # "The capital of France is"
    ]

    # (section header, keyword arguments for predict_next_token)
    variants = [
        ("Top predictions:", {"return_original": True}),
        ("Top predictions (LSH):", {}),
        ("Top predictions (LSH with dual):", {"dual_transform": True}),
    ]

    for prompt in test_prompts:
        print(f"\nPrompt: {prompt}")
        for variant_idx, (header, kwargs) in enumerate(variants):
            if variant_idx > 0:
                print("---------------")
            predictions = gemma_lsh.predict_next_token(prompt, **kwargs)
            print(header)
            for token, prob in predictions:
                print(f"Token: {token}, Probability: {prob:.4f}")


if __name__ == "__main__":
    main()
